module Data.Matrix.SmithNormalForm.Internal
    ( smithNF
    , rectifyDiagonal
    , diagonalize
    , isDiagonalMatrix
    , divides
    ) where

import qualified Data.Matrix as M
import qualified Data.Vector as V
import Data.Maybe (fromJust)

-- | Computes the Smith normal form of an integer matrix.
--
-- The matrix is diagonalized, its diagonal is rectified so that each entry
-- divides the next, and the entries are made nonnegative with 'abs'. The
-- resulting square diagonal matrix is then padded with zero rows (tall input)
-- or zero columns (otherwise) to match the input's shape.
smithNF :: Integral a => M.Matrix a -> M.Matrix a
smithNF m = padToShape (M.diagonalList (length invariants) 0 invariants)
  where
    invariants = map abs (rectifyDiagonal (diagonalize m))
    r = M.nrows m
    c = M.ncols m
    padToShape square
      | r > c = square M.<-> M.zero (r - c) c
      | otherwise = square M.<|> M.zero r (c - r)

-- | Given a diagonal matrix, outputs the list \([d_1,..,d_n]\) of its diagonal
-- entries rearranged (via elementary operations) so that \(d_k \mid d_{k+1}\).
-- Assumes the input is a diagonal matrix (not checked).
--
-- While some adjacent pair \(d_i, d_{i+1}\) violates the divisibility
-- condition, the first such pair is repaired. The repair writes \(d_{i+1}\)
-- directly below \(d_i\), which is the effect of adding column \(i+1\) to
-- column \(i\) of a diagonal matrix. The result is then re-diagonalized and
-- the check repeats.
rectifyDiagonal :: Integral a => M.Matrix a -> [a]
rectifyDiagonal diagonalMatrix =
  case violations of
    [] -> diags
    ((i, next) : _) ->
      rectifyDiagonal . diagonalize $
        M.setElem next (i + 1, i) (M.diagonalList (length diags) 0 diags)
  where
    diags = V.toList (M.getDiag diagonalMatrix)
    -- 1-based index i of every d_i that does not divide d_{i+1}, paired with
    -- d_{i+1}. A single linear scan replaces the former repeated (!!) lookups.
    violations =
      [ (i, d') | (i, d, d') <- zip3 [1 ..] diags (drop 1 diags), not (d `divides` d') ]


-- | Reduces a matrix to diagonal form using elementary row and column
-- operations, starting from the first row. The diagonal entries need not
-- satisfy the divisibility chain of the Smith normal form.
diagonalize :: Integral a => M.Matrix a -> M.Matrix a
diagonalize m = diagonalizer 1 m

-- | Worker for 'diagonalize'. It handles row @rowIndex@ and then moves on to
-- later rows.
--
-- Zero columns are first pushed to the right. A pivot is then chosen, and
-- 'improvePivot' / 'clearPivotRow' are applied until the rest of the pivot's
-- row and column are zero. Only then does it advance to the next row.
diagonalizer :: Integral a => Int -> M.Matrix a -> M.Matrix a
diagonalizer rowIndex m
  | rowIndex > M.nrows m = m
  | isDiagonalMatrix m = m
  | pivotPosition == (-1, -1) = m -- means there's no more cols
  -- Move nonzero columns left by swapping the first ADJACENT
  -- (zero column, nonzero column) pair.
  -- The previous code swapped the first zero column with its right neighbour.
  -- That neighbour may itself be zero (e.g. [[0,0,1]]), in which case the
  -- matrix did not change and the recursion never terminated.
  -- hasNonzeroAmongZeros holds exactly when zeroBeforeNonzero is nonempty;
  -- the pattern guard extracts the swap position totally.
  | hasNonzeroAmongZeros areZeroCols
  , (k : _) <- zeroBeforeNonzero
  = diagonalizer rowIndex (M.switchCols k (k + 1) m)
  | isZero restOfRow && isZero restOfCol = diagonalizer (rowIndex + 1) m
  | otherwise = diagonalizer rowIndex (clearPivotRow pivotPosition $ improvePivot pivotPosition (i + 1) mat)
  where
    areZeroCols = map isZero (cols m)
    -- 1-based indices k such that column k is zero and column k+1 is not
    zeroBeforeNonzero =
      [ k | (k, True, False) <- zip3 [1 ..] areZeroCols (drop 1 areZeroCols) ]
    (pivotPosition, mat) = choosePivot rowIndex m
    (i, j) = pivotPosition
    -- entries of the pivot row to the right of the pivot
    restOfRow = M.rowVector $ V.drop j $ M.getRow i m
    -- entries of the pivot column below the pivot
    restOfCol = M.colVector $ V.drop i $ M.getCol j m

-- | Given per-column "is this column zero?" flags, reports whether some zero
-- column is immediately followed by a nonzero column.
hasNonzeroAmongZeros :: [Bool] -> Bool
hasNonzeroAmongZeros flags =
  or (zipWith (\zeroHere zeroNext -> zeroHere && not zeroNext) flags (drop 1 flags))

-- | True when every off-diagonal entry is zero. Works for non-square
-- matrices too.
isDiagonalMatrix :: (Num a, Eq a) => M.Matrix a -> Bool
isDiagonalMatrix m = all zeroOffDiagonal [(i, j) | i <- [1 .. M.nrows m], j <- [1 .. M.ncols m]]
  where
    zeroOffDiagonal (i, j) = i == j || M.getElem i j m == 0

-- | Zeroes the entries to the right of the pivot at (@t@, @jt@) in its row.
-- It runs 'improvePivot' on the transpose, so the row operations there act
-- as column operations on the original matrix.
clearPivotRow :: Integral a => (Int, Int) -> M.Matrix a -> M.Matrix a
clearPivotRow (t, jt) = M.transpose . improvePivot (jt, t) (jt + 1) . M.transpose

-- | Zeroes every entry of column @jt@ in rows @rowIndex@ .. nrows, using the
-- pivot at (@t@, @jt@). This is where most of the real work happens.
--
-- Each step subtracts @q@ times the pivot row from the current row, where
-- @q = entry \`div\` pivot@.
--
-- * If the pivot divides the entry, that zeroes it, and the next row is
--   processed.
-- * Otherwise the remainder (nonzero, smaller in absolute value than the
--   pivot) is left behind. That row is swapped into the pivot position and
--   the same row is retried, so the pivot strictly shrinks in absolute value.
--
-- Fails with "Zero pivot entry." if the pivot is zero.
--
-- INVARIANT: only operations that keep the absolute value of the determinant
-- are allowed:
--     (I) scale a row by a unit
--    (II) switch two rows
--   (III) add a multiple of one row to another
improvePivot :: Integral a => (Int, Int) -> Int -> M.Matrix a -> M.Matrix a
improvePivot pos@(t, jt) r m
  | p == 0 = error "Zero pivot entry."
  | r > M.nrows m = m
  | p `divides` entry = improvePivot pos (r + 1) reduced
  | otherwise = improvePivot pos r (M.switchRows t r reduced)
  where
    p = M.getElem t jt m
    entry = M.getElem r jt m
    q = entry `div` p
    pivotRowVec = M.getRow t m
    -- row r minus q times the pivot row (mapRow passes a 1-based column index)
    reduced = M.mapRow (\c x -> x - q * (pivotRowVec V.! (c - 1))) r m

-- | @a \`divides\` b@ holds exactly when \(a \mid b\), i.e. @b@ is an integer
-- multiple of @a@. Zero is treated specially: it divides only zero.
divides :: Integral a => a -> a -> Bool
divides a b
  | a == 0 = b == 0
  | otherwise = b `mod` a == 0

---------------------
-- PIVOT-SELECTION --
---------------------
-- | Finds the pivot for row @rowIndex@ via 'nextPivotPosition'.
--
-- If the entry at that position is zero, a lower row with a nonzero entry in
-- that column is swapped up via 'makePivotNonzero'. Returns the pivot
-- position together with the (possibly row-swapped) matrix. When
-- 'nextPivotPosition' finds no candidate column, it returns the sentinel
-- @((-1, -1), m)@.
choosePivot :: Integral a => Int -> M.Matrix a -> ((Int, Int), M.Matrix a)
choosePivot rowIndex m =
  -- total case match instead of (== Nothing) followed by fromJust
  case nextPivotPosition rowIndex m of
    Nothing -> ((-1, -1), m)
    Just pos@(t, jt)
      | M.getElem t jt m /= 0 -> (pos, m)
      | otherwise -> (pos, makePivotNonzero pos m)

-- | Swaps row @t@ with the first row below it whose entry in column @jt@ is
-- nonzero, so the pivot position (@t@, @jt@) becomes nonzero.
--
-- ASSUMES: column @jt@ has a nonzero entry in some row @t' > t@. If it does
-- not, 'head' fails on an empty list.
makePivotNonzero :: Integral a => (Int, Int) -> M.Matrix a -> M.Matrix a
makePivotNonzero (t, jt) m = M.switchRows t firstNonzeroBelow m
  where
    pivotCol = M.getCol jt m
    firstNonzeroBelow =
      head [r | r <- [1 .. M.nrows m], r > t, pivotCol V.! (r - 1) /= 0]

-- | Proposes the pivot for row @rowIndex@. The proposal is @(rowIndex, j)@,
-- where @j@ is the first column with index @>= rowIndex@ that is not
-- entirely zero. Returns 'Nothing' when no such column exists.
--
-- ASSUMES: every row with index < @rowIndex@ has a single nonzero entry,
-- located in a column with index < @rowIndex@.
nextPivotPosition :: Integral a => Int -> M.Matrix a -> Maybe (Int, Int)
nextPivotPosition rowIndex m =
  case [j | j <- [max 1 rowIndex .. M.ncols m], not (isZero (col j m))] of
    [] -> Nothing
    (j : _) -> Just (rowIndex, j)

-- | True when every entry of the matrix equals zero (vacuously true for a
-- matrix with no entries).
isZero :: (Num a, Eq a) => M.Matrix a -> Bool
isZero = all (== 0) . M.toList

---------------------------
-- MATRIX HELPER METHODS --
---------------------------
-- | All rows of the matrix, top to bottom, each as a 1 x n matrix.
rows :: M.Matrix a -> [M.Matrix a]
rows m = [row k m | k <- [1 .. M.nrows m]]

-- | All columns of the matrix, left to right, each as an n x 1 matrix.
cols :: M.Matrix a -> [M.Matrix a]
cols m = [col k m | k <- [1 .. M.ncols m]]

-- | The @k@-th row (1-based) as a 1 x n matrix.
row :: Int -> M.Matrix a -> M.Matrix a
row k m = M.rowVector (M.getRow k m)

-- | The @k@-th column (1-based) as an n x 1 matrix.
col :: Int -> M.Matrix a -> M.Matrix a
col k m = M.colVector (M.getCol k m)
